{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "vscode": {
     "languageId": "plaintext"
    }
   },
   "source": [
    "# Introduction\n",
    "*Thomas Dooms*"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### What to expect\n",
    "\n",
    "This sequence of tutorials is a practical introduction to weight-based interpretability methods. We focus on methods that decompose bilinear layers into a set of interpretable components. These tutorials accompany [this paper](https://arxiv.org/abs/2410.08417), which covers the required math and discusses many experiments. Here, we focus on intuition, code, and generally getting our hands dirty exploring bilinear networks.\n",
    "\n",
    "By the end of this tutorial set, you should understand the following.\n",
    "- How bilinear layers enable weight-based interpretability (chapter 0).\n",
    "- How to decompose bilinear layers into important components (chapter 1).\n",
    "- How to analyze the weights of MLPs in deep transformer models (chapter 2).\n",
    "- How to fine-tune existing models to make them more interpretable (chapter 3).\n",
    "\n",
    "The tutorials should be seen as well-documented notebooks; they are quite text-heavy and are not fully self-contained. Learning about specific implementation details may require perusing the files that the notebooks import. Furthermore, these tutorials do not contain any exercise-solution pairs. We either cover an already-implemented experiment or discuss some open questions that we think may be worthwhile exploring.\n",
    "\n",
    "We made this because weight-based interpretability can quickly become math-heavy, but the underlying ideas are actually very simple. The aim is to give a step-by-step tutorial so anyone can implement the main experiments to interpret small image models and our pre-trained language models."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### A history of interpretability methods\n",
    "\n",
    "Any interpretability technique falls on a spectrum between weight-based and activation-based (sometimes called input-based). The former looks for structures that are meaningful to humans in large weight matrices; the latter studies the activations a model produces on specific inputs. Historically, the two approaches have alternated in popularity.\n",
    "\n",
    "- Before the deep learning explosion around 2010, researchers verified whether a given learning algorithm had converged to a desired solution by manually inspecting the weights. This generally involved small-scale studies.\n",
    "\n",
    "- Then came the era of heatmaps and other attribution methods, which used gradients (for a specific input) to understand how a model makes its classifications. While these methods had their utility, it became apparent that they were insufficient to understand what a model was doing.\n",
    "\n",
    "- A bit later, mechanistic interpretability entered the scene using circuit-based methods. These methods found compositional structure within image models based on a mix of weights and inputs. Later, the field shifted to language models, finding interpretable components that perform a given task.\n",
    "\n",
    "- Now, mechinterp has mostly shifted back towards analyzing activations, albeit on a large scale using sparse autoencoders (SAEs). These aim to find sparsely activating features and can be seen as a form of unsupervised probing (i.e., learning a collection of probes that can reconstruct all activations).\n",
    "\n",
    "The historical shift between the two is mostly due to their respective advantages and disadvantages: weight-based interpretability is hard, and input-based interpretability may miss important behaviors. "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Why does input-based interpretability miss behaviors?\n",
    "Input-based interpretability falls short in two respects, which I'll call universality and detectability.\n",
    "\n",
    "Assessing the universality of a feature (or circuit) using inputs is quite hard. The 'best' approach is to enumerate all input instances that activate the given feature. How a given factor influences that feature must then be checked explicitly, which can quickly lead to exponentially many combinations to try. Generally, there always exists some adversarial attack that will disturb the feature in unexpected ways. Beyond this 'in-distribution' universality, one can imagine 'out-of-distribution' universality. Consider a jailbreak; this can be seen as some input, often not in the training set, that makes a model activate features we do not want. In the next chapter, we show how analyzing the weights corresponding to given features may allow us to expose these behaviors early.\n",
    "\n",
    "Somewhat related is the detectability of a feature: whether we can find some specific knowledge or capability that could be hard to elicit. This is closely related to 'enumerative safety', which aims to catalog all possible capabilities. This may not even be possible when considering the whole input dataset. In planning lingo, such a capability is an unknown unknown, which the alignment literature regards as a serious threat. One example is the 'treacherous turn', where a model may intentionally hide certain capabilities, which will therefore not be found using input-based methods."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Why is weight-based interpretability hard?\n",
    "All contemporary models use some form of activation function (also called non-linearity). Non-polynomial activation functions (such as ReLU, Sigmoid, etc.) allow an arbitrarily large set of inputs to interact. Intuitively, consider a certain set of inputs that activate a ReLU output; ablating any of them could switch off the ReLU. In that sense, they interact non-linearly, where each input depends on all the others. We call such behavior non-decomposable: we cannot consider subsets of the input separately, as this would irrecoverably change the outputs. Hence, studying a ReLU requires the full input; otherwise, its behavior doesn't make sense.\n",
    "\n",
    "In summary, activation functions shroud the relation between inputs and outputs of any given layer. It is generally impossible to make statements about which inputs are important without actually computing them. However, many studies ignore or approximate this activation function. Once this is done, one can analyze a given layer with techniques from linear algebra. The disadvantage of doing this is that, due to the approximation, such techniques could produce results that are not in line with what the model is actually doing, with no way to provide (helpful) error bounds.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Advantages of weight-based interpretability\n",
    "\n",
    "Using the weights as a central object for interpretability has one huge advantage: they are enumerable. One can try to understand all weights but (generally) not all inputs. Weight-based decompositions such as the singular value decomposition (SVD) are useful to this end, as they allow us to quantify the importance of certain directions (through the singular values). SVD is generally very good at reducing the description length of a given matrix without losing too much accuracy. For interpretability, we can then consider a handful of dimensions that maximally contribute to the model."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### What are bilinear layers?\n",
    "Bilinear layers are a promising solution to many of the challenges explained above. In short, they obviate the need for the pesky activation function while retaining high accuracy. This last point seems crucial for any interpretable architecture.\n",
    "\n",
    "The intuition/motivation behind bilinear layers can be seen from two angles. The first is that a bilinear layer is a gated linear unit (GLU) without an activation function. This results in a sort of 'continuous' gate where either side can continuously determine the importance of the other side. The second intuition is that we are replacing the point-wise activation function of an ordinary linear unit with a learned function of the input that is multiplied elementwise.\n",
    "\n",
    "Importantly, bilinear layers are non-linear in their inputs (otherwise, we'd just have a normal matrix) but are linear in pairs of inputs. Put differently, if we use all pairs of inputs as a basis, this layer is fully linear. This means that a pair of inputs, which we call an interaction, is the fundamental unit of a bilinear layer. Consequently, we can fully describe how inputs interact for a given output. Since everything is linear in this 'interaction space', we can consider any output direction in exactly the same fashion. Hence, the central object in this tutorial is the 'interaction matrix', which encodes how strongly certain inputs interact towards an output."
   ]
  },
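  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the statement about pairs of inputs concrete, the computation can be written out as follows. This is a sketch in our own notation rather than a formula quoted from the paper: $W_l$, $W_r$, and $W_p$ are assumed to correspond to the `w_l`, `w_r`, and `w_p` attributes used in the code cells of this notebook, $y_c$ is one output, and biases are omitted.\n",
    "\n",
    "$$\n",
    "y_c = \\sum_h (W_p)_{ch} \\, (W_l x)_h \\, (W_r x)_h = \\sum_{i,j} \\Big( \\sum_h (W_p)_{ch} (W_l)_{hi} (W_r)_{hj} \\Big) x_i x_j = x^\\top Q_c \\, x,\n",
    "\\qquad Q_c = W_l^\\top \\, \\mathrm{diag}\\big((W_p)_{c,:}\\big) \\, W_r\n",
    "$$\n",
    "\n",
    "Each entry $(Q_c)_{ij}$ is the weight on the product $x_i x_j$, which is why the layer is linear once all pairs of inputs are used as a basis."
   ]
  },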
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Analyzing bilinear layers with interaction matrices\n",
    "\n",
    "To get a sense of these matrices, let's consider a toy task of binary computation. The task applies a binary operation to every pair of the $n$ input features. For instance, if we have 3 inputs and the operation is AND, the outputs should be AND($x_1$, $x_2$), AND($x_1$, $x_3$), AND($x_2$, $x_3$). The model consists of a single MLP: a bilinear layer (up projection + activation) and a head (down projection). You can take a look at the model source code; however, it's mostly implementation details.\n",
    "\n",
    "The following code snippet instantiates such a model with 4 inputs and trains it. The loss curve shows that the model has converged well."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "from toy import Model\n",
    "from plotly import express as px\n",
    "from einops import einsum\n",
    "\n",
    "# Instantiate the model according to some configuration\n",
    "model = Model.from_config(n_inputs=4, n_hidden=6, n_outputs=6, operation={\"and\": 1})\n",
    "# model = Model.from_config(n_inputs=4, n_hidden=6, n_outputs=6, operation={\"xor\": 1})\n",
    "\n",
    "# Train it and plot the loss\n",
    "loss = model.fit()\n",
    "px.line(y=loss, log_y=True, template=\"plotly_white\", labels=dict(y=\"loss\", x=\"epoch\"))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, we compute the interaction matrix for a given output feature. To do so, we consider one head output direction and combine that with both the left and right sides of the layer. At first, the resulting matrix doesn't really seem interpretable. The reason for this is somewhat mathematical and can be found in the paper. Intuitively, there are some invariances in these interactions that don't contribute to the output but can make interpretation harder. Luckily, we can remove them by symmetrizing the matrix; uncomment the marked line in the following code to do so. Afterward, we see a clean interaction pattern."
   ]
  },
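  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The invariance mentioned above can be made explicit with a short derivation. This is a sketch in our own notation ($Q$ for the unsymmetrized interaction matrix, $x$ for the input) and uses only the fact that $x^\\top Q x$ is a scalar, so it equals its own transpose.\n",
    "\n",
    "$$\n",
    "x^\\top Q x = (x^\\top Q x)^\\top = x^\\top Q^\\top x\n",
    "\\quad \\Rightarrow \\quad\n",
    "x^\\top Q x = x^\\top \\, \\tfrac{1}{2}(Q + Q^\\top) \\, x\n",
    "$$\n",
    "\n",
    "$$\n",
    "x^\\top A x = 0 \\quad \\text{for} \\quad A = \\tfrac{1}{2}(Q - Q^\\top)\n",
    "$$\n",
    "\n",
    "The antisymmetric part $A$ therefore never affects the output, so symmetrizing discards it without changing what the layer computes."
   ]
  },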
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compute the interaction matrix for a given output dimension\n",
    "q = einsum(model.w_p[1], model.w_l, model.w_r, \"hid, hid in1, hid in2 -> in1 in2\")\n",
    "\n",
    "# Symmetrize the interaction matrix, uncomment for a clean interaction matrix\n",
    "# .mT is the transpose of the last two dimensions (which does not matter here but will in the future)\n",
    "# q = 0.5 * (q + q.mT)\n",
    "\n",
    "px.imshow(q, color_continuous_midpoint=0, color_continuous_scale=\"RdBu\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's quickly cover why we get this interaction pattern: an AND operation should only produce an output when both features are active. Hence, features do not 'self-interact', meaning no output is active if only one is active. However, if both features are active, they 'cross-interact' towards the AND output. We see this pattern is similar across all outputs (only for different input features, as expected).\n",
    "\n",
    "Generally (as seen below), we compute all the interaction matrices at the same time. The result is a third-order tensor $B$ where each 'slice' is the interaction matrix of one output. This tensor encompasses the full computation of the layer."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compute the interaction matrices for all classes (outputs)\n",
    "# The contracted axis is the hidden dimension, as in the single-output einsum above\n",
    "b = einsum(model.w_p, model.w_l, model.w_r, \"cls hid, hid in1, hid in2 -> cls in1 in2\")\n",
    "\n",
    "# Symmetrize the interaction matrices\n",
    "b = 0.5 * (b + b.mT)\n",
    "\n",
    "px.imshow(b, color_continuous_midpoint=0, color_continuous_scale=\"RdBu\", facet_col=0)"
   ]
  },
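  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a worked example of the AND pattern discussed above, consider two inputs and assume (this is our assumption, not something taken from the model code) that they take values in $\\{0, 1\\}$. AND is then exactly the product $x_1 x_2$, which corresponds to an idealized symmetric interaction matrix:\n",
    "\n",
    "$$\n",
    "\\mathrm{AND}(x_1, x_2) = x_1 x_2 =\n",
    "\\begin{pmatrix} x_1 & x_2 \\end{pmatrix}\n",
    "\\begin{pmatrix} 0 & \\tfrac{1}{2} \\\\ \\tfrac{1}{2} & 0 \\end{pmatrix}\n",
    "\\begin{pmatrix} x_1 \\\\ x_2 \\end{pmatrix}\n",
    "$$\n",
    "\n",
    "The zero diagonal is the absence of self-interaction, and the off-diagonal entries are the cross-interaction. A trained model need not reproduce these values exactly; only the pattern is expected to be similar."
   ]
  },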
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You can play around with other binary operations, for example by switching to the commented-out `xor` configuration in the training cell; try to think about what interaction pattern they will produce beforehand. Some operations will produce strange interaction matrices because we are not using biases. Biases can be included in the visualization, but we keep things simple here."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Each interaction is meaningful by itself\n",
    "Unlike weights in an ordinary MLP, the interactions we study are meaningful by themselves. Because all interactions are simply summed at the end, there is no strange inter-dependence between these elements. This decomposability is extremely handy for interpretability. For instance, if we find an AND pattern between two features, we can state that they will always interact like that, no matter the other parts of the input. We will rely on this property heavily in the following chapters."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Future directions\n",
    "\n",
    "Beyond simple computational interaction matrices, one may also consider how a model compresses sparse features.\n",
    "From the experiments we have performed, bilinear layers behave differently in highly bottlenecked regimes.\n",
    "For antipodal superposition, they perform slightly worse than ReLU, but they perform better for complex geometries.\n",
    "\n",
    "Related works:\n",
    "- [Toy Models of Superposition](https://transformer-circuits.pub/2022/toy_model/index.html)\n",
    "- [Polysemanticity and Capacity in Neural Networks](https://arxiv.org/abs/2210.01892)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
